导航菜单

深度学习

作者:京东零售 刘岩扩散模型讲解前沿

人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Generative Adversarial Models,GAN),变分自编码器(Variance Auto-Encoder,VAE),标准化流模型(Normalization Flow, NF)以及这里要介绍的扩散模型(Diffusion Models,DM)。扩散模型是受到热力学中的一个分支,它的思想来源是非平衡热力学(Non-equilibrium thermodynamics)。扩散模型的算法理论基础是通过变分推断(Variational Inference)训练参数化的马尔可夫链(Markov Chain),它在许多任务上展现了超过GAN等其它生成模型的效果,例如最近非常火热的OpenAI的DALL-E 2,Stability.ai的Stable Diffusion等。这些效果惊艳的模型扩散模型的理论基础便是我们这里要介绍的提出扩散模型的文章[1]和非常重要的DDPM[2],扩散模型的实现并不复杂,但其背后的数学原理却非常丰富。在这里我会介绍这些重要的数学原理,但省去了这些公式的推导计算,如果你对这些推导感兴趣,可以学习参考文献[4,5,11]的相关内容。我在这里主要以一个相对简单的角度来讲解扩散模型,帮助你快速入门这个非常重要的生成算法。

1. 背景知识: 生成模型

目前生成模型主要有图1所示的四类。其中GAN的原理是通过判别器和生成器的互相博弈来让生成器生成足以以假乱真的图像。VAE的原理是通过一个编码器将输入图像编码成特征向量,它用来学习高斯分布的均值和方差,而解码器则可以将特征向量转化为生成图像,它侧重于学习生成能力。流模型是从一个简单的分布开始,通过一系列可逆的转换函数将分布转化成目标分布。扩散模型先通过正向过程将噪声逐渐加入到数据中,然后通过反向过程预测每一步加入的噪声,通过将噪声去掉的方式逐渐还原得到无噪声的图像,扩散模型本质上是一个马尔可夫架构,只是其中训练过程用到了深度学习的BP,但它更属于数学层面的创新。这也就是为什么很多计算机的同学看扩散模型相关的论文会如此费力。

图1:生成模型的四种类型 [4]

扩散模型中最重要的思想根基是马尔可夫链,它的一个关键性质是平稳性。即如果一个概率随时间变化,那么再马尔可夫链的作用下,它会趋向于某种平稳分布,时间越长,分布越平稳。如图2所示,当你向一滴水中滴入一滴颜料时,无论你滴在什么位置,只要时间足够长,最终颜料都会均匀的分布在水溶液中。这也就是扩散模型的前向过程。

图2:颜料分子在水溶液中的扩散过程

如果我们能够在扩散的过程颜料分子的位置、移动速度、方向等移动属性。那么也可以根据正向过程的保存的移动属性从一杯被溶解了颜料的水中反推颜料的滴入位置。这边是扩散模型的反向过程。记录移动属性的快照便是我们要训练的模型。

2. 扩散模型

在这一部分我们将集中介绍扩散模型的数学原理以及推导的几个重要性质,因为推导过程涉及大量的数学知识但是对理解扩散模型本身思想并无太大帮助,所以这里我会省去推导的过程而直接给出结论。但是我也会给出推导过程的出处,对其中的推导过程比较感兴趣的请自行查看。

2.1 计算原理

扩散模型简单的讲就是通过神经网络学习从纯噪声数据逐渐对数据进行去噪的过程,它包含两个步骤,如图3:

图3:DDPM的前向加噪和后向去噪过程

2.1.1 前向过程

2.1.2 后向过程

2.1.3 目标函数

那么问题来了,我们究竟使用什么样的优化目标才能比较好的预测高斯噪声的分布呢?一个比较复杂的方式是使用变分自编码器的最大化证据下界(Evidence Lower Bound, ELBO)的思想来推导,如式(6),推导详细过程见论文[11]的式(47)到式(58),这里主要用到了贝叶斯定理和琴生不等式。

式(6)的推导细节并不重要,我们需要重点关注的是它的最终等式的三个组成部分,下面我们分别介绍它们:

图4:扩散模型的去噪匹配项在每一步都要拟合噪音的真实后验分布和估计分布

真实后验分布可以使用贝叶斯定理进行推导,最终结果如式(8),推导过程见论文[11]的式(71)到式(84)。

\(p{\boldsymbol{\theta}}\left(\boldsymbol{x}{t-1} \mid \boldsymbol{x}t\right) = \mathcal N(\boldsymbol x{t-1}; \mu\theta(\boldsymbol x\_t, t), \Sigma\_q(t)) \tag9\)

2.1.4 模型训练

虽然上面我们介绍了很多内容,并给出了大量公式,但得益于推导出的几个重要性质,扩散模型的训练并不复杂,它的训练伪代码见算法1。

2.1.5 样本生成

2.2 算法实现2.2.1模型结构

DDPM在预测施加的噪声时,它的输入是施加噪声之后的图像,预测内容是和输入图像相同尺寸的噪声,所以它可以看做一个Img2Img的任务。DDPM选择了U-Net[9]作为噪声预测的模型结构。U-Net是一个U形的网络结构,它由编码器,解码器以及编码器和解码器之间的跨层连接(残差连接)组成。其中编码器将图像降采样成一个特征,解码器将这个特征上采样为目标噪声,跨层连接用于拼接编码器和解码器之间的特征。

图5:U-Net的网络结构

下面我们介绍DDPM的模型结构的重要组件。首先在U-Net的卷积部分,DDPM使用了宽残差网络(Wide Residual Network,WRN)[12]作为核心结构,WRN是一个比标准残差网络层数更少,但是通道数

相关推荐: